[Bugfix] (qwen3_tts): enable batched offline inference by fixing tensor slicing #1417

Open
RomanKoshkin wants to merge 7 commits into vllm-project:main from RomanKoshkin:batched_inference_fix

Conversation

@RomanKoshkin

Purpose

This PR fixes two bugs in the Qwen3-TTS model wrapper that prevented offline batched inference (max_batch_size > 1) from working correctly. Previously, passing a batch of inputs resulted in the engine either broadcasting the first request's output to all requests or crashing the worker due to length mismatches.

Specifically, this PR addresses:

  1. Input Truncation: In Qwen3TTSModelForGeneration.forward(), runtime_additional_information was hardcoded to pop index [0]. This has been updated to iterate over the runtime_info_list and properly accumulate batched inputs (texts, speakers, languages, ref_audio, etc.) into lists before passing them to the underlying generation methods.
  2. Output Tensor Slicing: In make_omni_output(), the batched audio_tensors array was hardcoded to only convert and return audio_tensors[0]. This has been fixed to convert the entire list of arrays to PyTorch tensors and return the full list in multimodal_outputs, allowing gpu_generation_model_runner.py to correctly map unique audio tensors back to their respective request IDs.
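The shape of both fixes can be sketched as below. This is a simplified stand-in, not the actual vllm-omni code: the field names follow the PR description, but the real signatures in `Qwen3TTSModelForGeneration.forward()` and `make_omni_output()` differ.

```python
import torch

# Sketch of fix 1: accumulate per-request runtime info instead of
# consuming only runtime_info_list[0]. Before the fix, every request in
# a batch reused the first request's text/speaker/ref_audio.
def accumulate_runtime_info(runtime_info_list: list[dict]) -> dict[str, list]:
    batched: dict[str, list] = {}
    for info in runtime_info_list:  # iterate, rather than pop index [0]
        for key, value in info.items():
            batched.setdefault(key, []).append(value)
    return batched

# Sketch of fix 2: convert *every* per-request audio array to a tensor,
# not just audio_tensors[0], so the runner can map each tensor back to
# its request ID.
def to_multimodal_outputs(audio_tensors: list) -> list[torch.Tensor]:
    return [torch.as_tensor(a) for a in audio_tensors]
```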

Test Plan

Tested offline batched inference using omni.generate() with a batch size of 80 requests.

Since this is a fix to the core multimodal routing logic, no new automated tests are required. However, the fix can be manually verified using this minimal offline batching script:

Test Script
import os
from typing import NamedTuple  # noqa: UP035

import soundfile as sf

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from vllm import SamplingParams

from vllm_omni import Omni


class QueryResult(NamedTuple):
    """Container for a prepared Omni request."""

    inputs: dict
    model_name: str

def get_base_query(
    ref_audios: list[str],
    ref_texts: list[str],
    target_texts: list[str],
    target_langs: list[str],
):

    inputs = []
    for target_text, target_lang, ref_audio, ref_text in zip(
        target_texts,
        target_langs,
        ref_audios,
        ref_texts,
    ):
        prompt = f"<|im_start|>assistant\n{target_text}<|im_end|>\n<|im_start|>assistant\n"
        print(prompt)
        inputs.append(
            {
                "prompt": prompt,
                "additional_information": {
                    "task_type": ["Base"],
                    "ref_audio": [ref_audio],
                    "ref_text": [ref_text],
                    "text": [target_text],
                    "language": [target_lang],
                    "x_vector_only_mode": [False],
                    "max_new_tokens": [8192],
                },
            }
        )

    return QueryResult(
        inputs=inputs,
        model_name="Qwen/Qwen3-TTS-12Hz-1.7B-Base",
    )

def main():

    omni = Omni(
        model="Qwen/Qwen3-TTS-12Hz-1.7B-Base",
        stage_configs_path="vllm_omni/model_executor/stage_configs/qwen3_tts.yaml",
        log_stats=True,
        stage_init_timeout=300,
    )


    target_texts = [
        'Welcome to another episode of Out of the Pods.',
        "I'm Deep T. And I'm Natalie.",
        'And happy Wednesday.',
        'You know, we said last week that this episode is going to be about our recap of Perfect Match Season 2, Episodes 1 through 6, which we will get into.',
        'Lots of thoughts.',
        'Actually, almost no thoughts because...',
        'This is not a great season.',
        "It's just not off to a good start.",
        'I feel like I lost some brain cells watching it.',
        'Oh, 100%.'
    ]
    ref_audios = ["https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-TTS-Repo/clone_2.wav"] * len(target_texts)
    ref_texts = [
       "Okay. Yeah. I resent you. I love you. I respect you. But you know what? You blew it! And thanks to you.",
    ] * len(target_texts)
    target_langs = ["English"] * len(target_texts)


    query_result = get_base_query(ref_audios, ref_texts, target_texts, target_langs)

    sampling_params = SamplingParams(
        temperature=0.9,
        top_p=1.0,
        top_k=50,
        max_tokens=8192,
        seed=42,
        detokenize=False,
        repetition_penalty=1.05,
    )

    sampling_params_list = [
        sampling_params,
    ]

    output_dir = "vllm-omni/examples/offline_inference/qwen3_tts/output"
    os.makedirs(output_dir, exist_ok=True)

    omni_generator = omni.generate(query_result.inputs, sampling_params_list)
    for stage_outputs in omni_generator:
        for output in stage_outputs.request_output:
            request_id = output.request_id
            audio_tensor = output.outputs[0].multimodal_output["audio"].clone()
            print(f"audio_tensor: {audio_tensor.shape}")
            output_wav = os.path.join(output_dir, f"output_{request_id}.wav")
            audio_samplerate = output.outputs[0].multimodal_output["sr"].item()
            # Convert to numpy array and ensure correct format
            audio_numpy = audio_tensor.float().detach().cpu().numpy()

            # Ensure audio is 1D (flatten if needed)
            if audio_numpy.ndim > 1:
                audio_numpy = audio_numpy.flatten()

            # Save audio file with explicit WAV format
            sf.write(output_wav, audio_numpy, samplerate=audio_samplerate, format="WAV")
            print(f"Request ID: {request_id}, Saved audio to {output_wav}")

if __name__ == "__main__":
    main()

…or slicing

Signed-off-by: Roman Koshkin <roman.koshkin@sbintuitions.co.jp>
@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 6dd5ad26bd

@lishunyang12 (Contributor) left a comment
Thanks for the contribution; the core batch fix is correct. A few cleanup items need to be addressed before this lands.

…xing tensor slicing

Signed-off-by: Roman Koshkin <roman.koshkin@sbintuitions.co.jp>
@RomanKoshkin
Author

@lishunyang12 Thanks for the review. I've addressed almost all the comments (except the NPU one, as I can't properly test it on my side). I've re-run the tests for offline_inference/qwen3_tts/end2end.py and it works normally.

@RomanKoshkin
Author

@lishunyang12 I've also pushed the updated code to another branch.

@lishunyang12 (Contributor) left a comment
Nice cleanup — most comments addressed. A couple things still open.

RomanKoshkin and others added 3 commits February 21, 2026 16:22
…ixing tensor slicing

Signed-off-by: Roman Koshkin <roman.koshkin@sbintuitions.co.jp>
…ixing tensor slicing

Signed-off-by: Roman Koshkin <roman.koshkin@sbintuitions.co.jp>
Signed-off-by: Roman Koshkin <26285991+RomanKoshkin@users.noreply.github.com>
@hsliuustc0106 added the "ready label to trigger buildkite CI" label on Feb 23, 2026
@hsliuustc0106
Collaborator

can you add some performance tests here?

@hsliuustc0106
Collaborator

@tzhouam qwen-tts also uses additional_info

@RomanKoshkin
Author

RomanKoshkin commented Feb 23, 2026

> can you add some performance tests here?

@hsliuustc0106


Device: Nvidia A100 80GB

Test inputs: 40 sentences (varying length, English)

#### Generating one-by-one, via the API (online, batch_size=1)

Wall time: 6 min 56 s

#### Offline batched (batch_size=40)

Wall time: 50.3 s

[Overall Summary]👇
+-----------------------------+------------+
| Field                       |      Value |
+-----------------------------+------------+
| e2e_requests                |         40 |
| e2e_wall_time_ms            | 50,329.108 |
| e2e_total_tokens            |      1,166 |
| e2e_avg_time_per_request_ms |  1,258.228 |
| e2e_avg_tokens_per_s        |     23.168 |
| e2e_stage_0_wall_time_ms    | 50,319.723 |
+-----------------------------+------------+

@lishunyang12 (Contributor) left a comment
Looks good now — all the previous concerns are addressed. The list-of-tensors output in make_omni_output is handled correctly by the generation model runner's dict branch. Nice work on the batched inference fix.

@linyueqian (Contributor) left a comment
LGTM. Core batching logic is correct. Previous review feedback has been addressed.

A few minor (non-blocking) suggestions:

  1. Guard against empty runtime_info_list: If runtime_info_list is [], task_types[0] will raise an uncontrolled IndexError. A quick early check would be cleaner.

  2. Warn on inconsistent non-batched kwargs: For scalar params like max_new_tokens, only the first request's value is used silently. Consider logging a warning if values differ across requests.

  3. Use a set for key filtering: `if k not in ["text", "task_type", ...]` → `if k not in {"text", "task_type", ...}` (minor, but cleaner).

  4. Extract batched_keys as a class constant: The set {"ref_audio", "ref_text", ...} encodes model-specific knowledge — promoting it to a class-level constant with a brief comment improves maintainability.
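The four suggestions above could be combined along these lines. This is an illustrative sketch against a simplified stand-in class: `BATCHED_KEYS`, the field names, and the class name are assumptions, not the actual vllm-omni identifiers.

```python
import logging

logger = logging.getLogger(__name__)

class Qwen3TTSInfoAccumulator:
    # Suggestion 4: the model-specific per-request keys promoted to a
    # class constant. These fields vary per request and must be gathered
    # into lists; everything else is treated as a scalar kwarg.
    BATCHED_KEYS = {"text", "task_type", "ref_audio", "ref_text", "language"}

    def accumulate(self, runtime_info_list: list[dict]) -> dict:
        # Suggestion 1: fail early on an empty batch instead of letting
        # a later task_types[0] raise an uncontrolled IndexError.
        if not runtime_info_list:
            raise ValueError("runtime_info_list is empty; nothing to generate")

        batched: dict[str, list] = {k: [] for k in self.BATCHED_KEYS}
        scalars: dict = {}
        for info in runtime_info_list:
            for k, v in info.items():
                # Suggestion 3: set membership test instead of a list.
                if k in self.BATCHED_KEYS:
                    batched[k].append(v)
                elif k in scalars and scalars[k] != v:
                    # Suggestion 2: scalar kwargs (e.g. max_new_tokens)
                    # keep only the first value; warn if requests disagree.
                    logger.warning(
                        "Inconsistent %r across batch: keeping %r, ignoring %r",
                        k, scalars[k], v,
                    )
                else:
                    scalars.setdefault(k, v)
        return {**batched, **scalars}
```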

@hsliuustc0106
Collaborator

@vllm-omni-reviewer

Gaohan123 and others added 2 commits February 26, 2026 10:23
Signed-off-by: Roman Koshkin <26285991+RomanKoshkin@users.noreply.github.com>